package org.zjvis.datascience.service;

import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import org.zjvis.datascience.common.constant.DatabaseConstant;
import org.zjvis.datascience.common.dto.*;
import org.zjvis.datascience.common.enums.TaskTypeEnum;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.model.ApiResult;
import org.zjvis.datascience.common.vo.FolderVO;
import org.zjvis.datascience.common.vo.JobStatusVO;
import org.zjvis.datascience.common.vo.MLModelVO;
import org.zjvis.datascience.common.util.RestTemplateUtil;
import org.zjvis.datascience.service.mapper.MLModelMapper;
import org.zjvis.datascience.service.mapper.FolderMapper;
import org.zjvis.datascience.service.mapper.TaskMapper;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;

import java.util.List;
import java.util.Locale;

import static org.zjvis.datascience.common.constant.DatabaseConstant.DATASOURCE_TYPE_POSTGRES;

/**
 * Service for MLModel (machine-learning operator) entities.
 *
 * <p>Responsibilities visible in this class:
 * <ul>
 *   <li>model and folder persistence through {@link MLModelMapper} / {@link FolderMapper};</li>
 *   <li>dropping a model's output table in the default PostgreSQL ("GP") database;</li>
 *   <li>submitting / cancelling training jobs through the Flask server and {@link RestTemplateUtil}.</li>
 * </ul>
 *
 * @description MLModel machine-learning operator service
 * @date 2021-11-23
 */
@Service
public class MLModelService {

    /** Username for the default PostgreSQL ("GP") connection, injected from {@code postgres.username}. */
    @Value("${postgres.username}")
    private String USER_GP;

    /** Password for the default PostgreSQL ("GP") connection, injected from {@code postgres.password}. */
    @Value("${postgres.password}")
    private String PASSWORD_GP;

    protected static final Logger logger = LoggerFactory.getLogger("MLModelService");

    @Autowired
    private RestTemplateUtil restTemplateUtil;

    @Autowired
    private MLModelMapper mlmodelMapper;

    @Autowired
    private FolderMapper folderMapper;

    @Autowired
    private TaskMapper taskMapper;

    // @Lazy presumably breaks a circular dependency with DatabaseService — confirm before removing.
    @Lazy
    @Autowired
    private DatabaseService databaseService;

    /**
     * Case-insensitive substring test. Uses {@link Locale#ROOT} so the result does not depend
     * on the JVM default locale (e.g. Turkish dotless-i folding of "RUNNING").
     */
    private boolean isContainKey(String name, String searchKey) {
        return name.toLowerCase(Locale.ROOT).contains(searchKey.toLowerCase(Locale.ROOT));
    }

    /**
     * Inserts a model record.
     *
     * @param model model to persist; its id is expected to be populated by the mapper
     * @return the id read back from {@code model} after the insert
     */
    public Long save(MLModelDTO model) {
        mlmodelMapper.save(model);
        return model.getId();
    }

    /** Marks the model with the given id invisible (delegates to the mapper). */
    public void setInvisible(Long id) {
        mlmodelMapper.setInvisible(id);
    }

    /** Updates a model record; returns the mapper's result. */
    public boolean update(MLModelDTO model) {
        return mlmodelMapper.update(model);
    }

    /** Loads the metrics of the model with the given id. */
    public MLModelDTO queryMetricsById(Long id) {
        return mlmodelMapper.queryMetricsById(id);
    }

    /** Lists the models matching the project criteria in {@code vo}. */
    public List<MLModelDTO> queryByProjectId(MLModelVO vo) {
        return mlmodelMapper.queryByProjectId(vo);
    }

    /** Updates the training timestamp of the model with the given id. */
    public void updateTrainTime(Long id) {
        mlmodelMapper.updateTrainTime(id);
    }

    /**
     * Opens a new JDBC connection to the default dataset database using the PostgreSQL driver.
     * The caller owns the returned connection and must close it.
     *
     * @return an open connection, never {@code null}
     * @throws RuntimeException if the default database config is missing or the connection fails
     */
    public Connection getGPConn() {
        DatabaseDTO databaseDTO = databaseService.queryById(DatabaseConstant.DEFAULT_DATASET_ID);
        if (databaseDTO == null) {
            // Fail with context instead of an anonymous NullPointerException below.
            throw new RuntimeException(
                    "MLModelService get jdbc connection failed. since default database config not found");
        }
        try {
            Class.forName("org.postgresql.Driver");
            return DriverManager.getConnection(
                    databaseDTO.getConnectionURL(DATASOURCE_TYPE_POSTGRES), USER_GP, PASSWORD_GP);
        } catch (SQLException | ClassNotFoundException e) {
            throw new RuntimeException(String.format("MLModelService get jdbc connection failed. since %s", e.getMessage()), e);
        }
    }

    /**
     * Executes {@code drop table if exists <tableName>} on the default database.
     *
     * <p>NOTE(review): {@code tableName} is interpolated verbatim into SQL (identifiers cannot be
     * bound as parameters). Callers must never pass user-controlled text.
     *
     * @param tableName (optionally schema-qualified) table name
     * @return 1 on success, -1 if the statement failed, -2 if closing the connection failed
     * @throws RuntimeException if no connection can be obtained (see {@link #getGPConn()})
     */
    public int dropTable(String tableName) {
        Connection conn = getGPConn();
        String sql = String.format("drop table if exists %s", tableName);
        int result = 1;
        // try-with-resources closes the Statement, which the original code leaked.
        try (Statement stat = conn.createStatement()) {
            stat.execute(sql);
        } catch (SQLException e) {
            logger.error("MLModelService drop table {} failed", tableName, e);
            result = -1;
        } finally {
            try {
                conn.close();
            } catch (SQLException e) {
                logger.warn("MLModelService close connection failed after dropping {}", tableName, e);
                result = -2;
            }
        }
        return result;
    }

    /**
     * Deletes a model. It first drops its output table {@code ml_model.output_<id>}, then
     * removes the model record.
     *
     * <p>A failure to drop the table is logged but does not prevent deleting the record.
     */
    public void delete(Long id) {
        String tableName = "ml_model.output_" + id;
        if (dropTable(tableName) != 1) {
            logger.warn("MLModelService failed to drop output table {} for model {}", tableName, id);
        }
        mlmodelMapper.delete(id);
    }

    /** Lists the models matching the status criteria in {@code model}. */
    public List<MLModelDTO> queryByStatus(MLModelVO model) {
        return mlmodelMapper.queryByStatus(model);
    }

    /** Submits a job to the Flask server via {@link RestTemplateUtil}; returns its raw response body. */
    public String submitFlaskJob(JSONObject params) throws IOException {
        return restTemplateUtil.submitFlaskJob(params);
    }

    /**
     * Starts training asynchronously.
     *
     * <p>Every entry of {@code params} is POSTed as a url-encoded form field to
     * {@code <flaskServer>/<params.apiPath>}. The call returns immediately. The response is only
     * logged when it arrives.
     *
     * @param params form fields; must contain {@code apiPath}
     * @return the constant {@code "training began"}
     * @throws IOException declared for API compatibility; not thrown by this implementation
     */
    public String beginTraining(JSONObject params) throws IOException {
        String flaskServer = restTemplateUtil.getFlaskServer();
        String url = String.format("%s/%s", flaskServer, params.getString("apiPath"));
        logger.warn("---------> url: {} ", url);

        MultiValueMap<String, String> formDataMap = new LinkedMultiValueMap<>();
        for (String keyStr : params.keySet()) {
            formDataMap.add(keyStr, params.getString(keyStr));
        }
        logger.warn("formData -> {}", formDataMap);

        // BodyInserters.fromFormData already sets Content-Type: application/x-www-form-urlencoded.
        Flux<String> flux = WebClient.create().post().uri(url)
                .body(BodyInserters.fromFormData(formDataMap)).retrieve().bodyToFlux(String.class);

        // Subscribe with an error consumer so that failed requests are logged instead of being
        // dropped through Reactor's "error callback not implemented" path.
        flux.subscribe(
                result -> logger.info("________________training finished_______with result:_____{}", result),
                error -> logger.error("training request to {} failed", url, error));
        return "training began";
    }

    /** Lists the model-panel entries matching {@code vo}. */
    public List<MLModelDTO> queryModelPanel(MLModelVO vo) {
        return mlmodelMapper.queryModelPanel(vo);
    }

    /**
     * Stops a running training job and updates the model status via the mapper.
     *
     * <p>The progress id is read from {@code JobStatusVO.finalStatus}. When it contains
     * {@code "application"} (presumably a cluster application id — confirm), the job is killed
     * through {@link RestTemplateUtil#killJob}. The status is updated in either case.
     *
     * @return {@code false} if no status exists or the job is not RUNNING, otherwise {@code true}
     * @throws DataScienceException if the remote kill request fails
     */
    public boolean killTraining(MLModelVO model) {
        JobStatusVO jobstatus = mlmodelMapper.queryStatus(model.getId());
        if (jobstatus == null || jobstatus.getState() == null
                || !isContainKey(jobstatus.getState(), "RUNNING")) {
            return false;
        }
        String progressId = jobstatus.getFinalStatus();
        if (progressId != null && progressId.contains("application")
                && !restTemplateUtil.killJob(progressId)) {
            throw new DataScienceException("系统异常，请稍后重试");
        }
        mlmodelMapper.updateStatus(model);
        return true;
    }

    /** Updates the model status; always returns {@code true}. */
    public boolean updateStatus(MLModelVO model) {
        mlmodelMapper.updateStatus(model);
        return true;
    }

    /**
     * Checks whether a model is referenced by a model task of the given pipeline.
     * This is presumably used to refuse deleting a model that is still in use; verify with callers.
     *
     * @return {@code modelId} if some model-type task's data JSON has a matching {@code modelId},
     *         otherwise {@code 0L} (also when either argument is {@code null})
     */
    public Long modelDelAuth(Long modelId, Long pipelineId) {
        if (modelId == null || pipelineId == null) {
            return 0L;
        }
        for (TaskDTO task : taskMapper.queryByPipeline(pipelineId)) {
            // Check the type first so JSON is only parsed for model tasks.
            if (task.getType() != TaskTypeEnum.TASK_TYPE_MODEL.getVal()) {
                continue;
            }
            JSONObject dataJson = JSONObject.parseObject(task.getDataJson());
            if (dataJson != null && dataJson.containsKey("modelId")
                    && modelId.equals(dataJson.getLong("modelId"))) {
                return modelId;
            }
        }
        return 0L;
    }

    /**
     * Returns the stored job status for model {@code id} when {@code appId} is null or empty.
     * Remote lookup by {@code appId} is not implemented, so an empty {@link JobStatusVO} is
     * returned otherwise.
     */
    public JobStatusVO queryJobstatus(String appId, Long id) {
        if (appId == null || appId.isEmpty()) {
            return mlmodelMapper.queryStatus(id);
        }
        return new JobStatusVO();
    }

    /** Inserts a folder and returns the id read back from {@code folder}. */
    public Long createFolder(FolderDTO folder) {
        folderMapper.save(folder);
        return folder.getId();
    }

    /** Updates a folder record. */
    public void updateFolder(FolderDTO dto) {
        folderMapper.updateFolder(dto);
    }

    /** Lists the models matching the folder criteria in {@code vo}. */
    public List<MLModelDTO> queryFolder(FolderVO vo) {
        return folderMapper.queryFolder(vo);
    }

    /** Lists the folders matching {@code vo}. */
    public List<FolderDTO> getFolders(MLModelVO vo) {
        return folderMapper.getFolders(vo);
    }

    /** Loads a single folder matching {@code vo}. */
    public FolderDTO getFolderById(MLModelVO vo) {
        return folderMapper.getFolderById(vo);
    }

    /** Lists the models contained in the given folder. */
    public List<MLModelVO> queryModelByFolderId(Long folderId) {
        return folderMapper.queryModelByFolderId(folderId);
    }

    /** Returns the number of running models for the given user (delegates to the mapper). */
    public Long queryNumRunning(Long userId) {
        return mlmodelMapper.queryNumRunning(userId);
    }
}
